from BasicLibraries import *
import regression as rg
from matplotlib import rc

# Global matplotlib styling applied when this module is imported:
# LaTeX text rendering (loading the cmbright package), 24pt base and legend
# font sizes, a sans-serif Helvetica font family, and 8pt padding between the
# x-axis major ticks and their labels.
plt.rcParams.update({
    'xtick.major.pad': '8',
    'text.usetex': True,
    'font.size': 24,
    'legend.fontsize': 24,
    'text.latex.preamble': r'\usepackage{cmbright}',
    'font.family': 'sans-serif',
    'font.sans-serif': ['Helvetica'],
})

# Module-level colour palette (hex RGB strings).
color = ['#e95057', '#0080c4', '#00a989', '#464646']


#%% Effect of diversification
def estimate_effect_of_effort(dt_ws, dummies, YFE, sdgs_options_name, diversification_type,
                              figure_path='Figures/Figure4A_SI.pdf', ax=None):
    """Regress emission outcomes on effort measures, print a LaTeX table and plot the effects.

    For each of the outcomes 'DirectControl', 'Emission_intensity',
    'DirectControl_fut_tot' and 'Emission_intensity_fut_tot' this calls
    ``rg.make_effort_regression`` and collects its formatted result row and
    its numeric coefficient table. The stacked result rows (plus indicator
    rows describing the specification) are printed as LaTeX, and the
    coefficients are drawn as points with error bars, one series per outcome.

    Parameters
    ----------
    dt_ws, dummies
        Passed unchanged to ``rg.make_effort_regression``.
    YFE
        Passed as ``years_FE``; its truthiness also sets the
        'Year fixed effects' row of the printed table.
    sdgs_options_name, diversification_type
        Only printed, to label the console output.
    figure_path : str, path-like or None, optional
        Where the figure is saved (default ``'Figures/Figure4A_SI.pdf'``).
        Missing parent directories are created. ``None`` skips saving.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to ``plt.subplot()`` on the current
        figure, which may return an Axes already drawn on by an earlier
        call; pass a fresh Axes to avoid overlaying plots.

    Returns
    -------
    res_bar : DataFrame
        Per-regressor 'Effect', 'LB', 'UB', 'Uncertainty' (half the UB-LB
        width) and 'Emission type' (display label), stacked over outcomes.
    list_of_models : dict
        Outcome name -> ``models`` returned by the regression helper.
    list_of_data : dict
        Outcome name -> ``x`` returned by the regression helper.
    tab : DataFrame
        One row per outcome followed by the specification indicator rows.
    """
    import os
    from pathlib import Path

    dep_vars = ['DirectControl', 'Emission_intensity',
                'DirectControl_fut_tot', 'Emission_intensity_fut_tot']
    print(sdgs_options_name, '-', diversification_type)

    tables = []
    effects = []
    list_of_models = {}
    list_of_data = {}
    for dep_var in dep_vars:
        res, resNUM, _full_table, x, models, _baseline = rg.make_effort_regression(
            dt_ws, dummies, emission_var=dep_var, years_FE=YFE)
        res.index = [dep_var]
        # Columns 0-2 of resNUM are the estimate and its lower/upper bound;
        # append the symmetric error-bar half-width as column 3.
        resNUM[3] = (resNUM[2] - resNUM[1]) / 2
        resNUM.columns = ['Effect', 'LB', 'UB', 'Uncertainty']
        resNUM.index = res.columns
        resNUM = resNUM.rename(index={'Risk mitigation': 'Risk Mit.',
                                      'Stakeholders engagement': 'Stak.Eng.'})
        resNUM['Emission type'] = dep_var
        effects.append(resNUM)
        tables.append(res)
        list_of_models[dep_var] = models
        list_of_data[dep_var] = x

    # Concatenate once rather than growing from an empty DataFrame: avoids
    # repeated copies and pandas' FutureWarning about empty concat entries.
    tab = pd.concat(tables)
    res_bar = pd.concat(effects)
    n_regressors = len(effects[0])

    # Indicator rows appended below the outcome rows (same value in every column).
    tab.loc["Assets' characteristics"] = 'Yes'
    tab.loc['Sector fixed effects'] = 'Yes'
    tab.loc['Country fixed effects'] = 'Yes'
    tab.loc['Year fixed effects'] = 'Yes' if YFE else 'No'
    tab.loc['Self-selectivity'] = 'Yes'
    print(tab.to_latex())

    #=== Make the plot
    outcome_labels = {
        'DirectControl': 'GHG Emissions',
        'DirectControl_fut_tot': 'Future GHG emissions',
        'Emission_intensity': 'GHG Emissions (intensity)',
        'Emission_intensity_fut_tot': 'Future GHG emissions (intensity)',
    }
    res_bar['Emission type'] = res_bar['Emission type'].replace(outcome_labels)
    res_bar = res_bar.rename(index={'Total_effort': 'Total effort ',
                                    'Total_effort_risk mitigation': 'Risk Mit.(TE)',
                                    'Total_effort_stakeholders engagement': 'Stak.Eng.(TE)',
                                    'Total_effort_innovation': 'Innovation (TE)'})

    if ax is None:
        ax = plt.subplot()
    colors = ['#454f83', '#18aca4', '#960018', '#d87d0b']
    markers = ['o', '*', 'o', '*']
    shifts = [-0.1, -0.05, 0.05, 0.1]
    styles = zip(res_bar['Emission type'].unique(), colors, markers, shifts)
    for label, colour, marker, shift in styles:
        subset = res_bar[res_bar['Emission type'] == label].copy()
        # Regressors sit at x = 1..n; each outcome is nudged sideways so the
        # error bars of different outcomes do not overlap.
        subset.index = np.arange(1, len(subset) + 1) + shift
        # The series name becomes the legend entry.
        subset = subset.rename(columns={'Effect': label})
        subset[label].plot(yerr=subset['Uncertainty'], ls='', ms=15, capsize=4,
                           c=colour, marker=marker, ax=ax)

    ax.axhline(y=0, ls='-.', lw=0.5, c='k')
    ax.set_ylabel('Effect on emissions,\n Standardised coefficients')
    ax.legend(loc='center', bbox_to_anchor=(0.5, 1.12), ncol=2)
    ax.set_xlabel('')
    ax.margins(x=.1)
    ax.set_xticks(np.arange(1, n_regressors + 1, 1.0))
    ax.set_xticklabels(list(res_bar.index[:n_regressors]))
    # Vertical separator between the first regressor and the rest.
    ax.axvline(x=1.5, ls='-.', c='k', lw=0.5)

    if figure_path is not None:
        if isinstance(figure_path, (str, os.PathLike)):
            Path(figure_path).parent.mkdir(parents=True, exist_ok=True)
        ax.figure.savefig(figure_path, dpi=100, bbox_inches='tight')
    return res_bar, list_of_models, list_of_data, tab